import math
import os
import sys
from tqdm import tqdm
import re
sys.path.append("./")
from utils.logging_utils import setup_logger_to_stdout
import json
import demjson3
import ast

# Module-wide logger shared by every scorer class below; the factory comes
# from utils.logging_utils (imported above).
logger = setup_logger_to_stdout()

class Base_Model_PRE_PROCESS:
    """Shared bookkeeping for scoring GUI-agent predictions against ground truth.

    Holds the per-action-type counters (``action_stats``), the list of scored
    records (``results``) and the aggregation / saving logic. Subclasses supply
    model-specific parsing and ``Is_action_success``.
    """

    # Action-name prefixes checked by ``get_action_type``, in match order.
    # A prefix's integer code is its 1-based position here. That is also the
    # 1-based position of the same name in ``self.action_types``.
    _ACTION_PREFIXES = (
        "CLICK", "LONG_CLICK", "TYPE", "SCROLL", "PRESS_BACK", "PRESS_HOME",
        "OPENAPP", "WAIT", "COMPLETE", "ENTER", "PRESS_APPSELECT",
        "INCOMPLETE", "MOVETO", "PRESS_SPACE",
    )

    def __init__(self) -> None:
        # Stats bucket names; "TOTAL" aggregates over every scored type.
        self.action_types = list(self._ACTION_PREFIXES) + ["TOTAL"]
        # Per bucket: number of samples ('count'), type matches ('type'),
        # full matches ('success') and the derived rates TMR = type / count
        # and AMR = success / count (filled in by post_process_res).
        self.action_stats = {
            action: {'count': 0, 'success': 0, 'type': 0, 'TMR': 0, 'AMR': 0}
            for action in self.action_types
        }
        # Episode-level task success rate, filled in by post_process_res.
        self.action_stats['TSR'] = 0
        self.results = []

    def get_action_type(self, action):
        """Map an upper-case action string to its integer code.

        The first prefix in ``_ACTION_PREFIXES`` that ``action`` starts with
        wins (1 = CLICK ... 14 = PRESS_SPACE). Returns 0 if none match.
        """
        for code, prefix in enumerate(self._ACTION_PREFIXES, start=1):
            if action.startswith(prefix):
                return code
        return 0

    def extract_thought(self, output_text):
        """Return the model's reasoning text; the base format has none, so ''."""
        return ""

    def calculate_TSR(self, res):
        """Return the fraction of episodes in which every sample succeeded.

        Samples are grouped by their 'episode_id'. Scored records only carry
        'is_success' when they succeed, so a missing key counts as a failure.
        Returns 0 when ``res`` is empty.
        """
        from collections import defaultdict

        episodes = defaultdict(list)
        for sample in res:
            episodes[sample['episode_id']].append(sample)
        if not episodes:
            return 0
        successful_episodes = sum(
            all(sample.get('is_success', False) for sample in samples)
            for samples in episodes.values()
        )
        return successful_episodes / len(episodes)

    def Is_action_type(self, raw_action1, raw_action2):
        """Return True when both action strings resolve to the same action code."""
        return self.get_action_type(raw_action1) == self.get_action_type(raw_action2)

    def post_process_res(self, args):
        """Merge buckets, compute TMR/AMR/TSR, log them and save to ``args.result_path``.

        LONG_CLICK is folded into CLICK. PRESS_BACK, PRESS_HOME and
        PRESS_SPACE are summed into a new "PRESS" bucket, and the fine-grained
        buckets are then removed. This mutates ``action_stats``, so it is
        meant to run once per scorer instance.
        """
        stats = self.action_stats
        press = stats.setdefault("PRESS", {})
        for key in stats["PRESS_BACK"]:
            press[key] = (
                stats["PRESS_BACK"].get(key, 0)
                + stats["PRESS_HOME"].get(key, 0)
                + stats["PRESS_SPACE"].get(key, 0)
            )
            stats["CLICK"][key] += stats["LONG_CLICK"][key]

        for key_to_delete in ("LONG_CLICK", "PRESS_BACK", "PRESS_HOME", "PRESS_SPACE"):
            stats.pop(key_to_delete, None)

        for action_type, bucket in stats.items():
            if action_type == 'TSR':
                continue
            count = bucket['count']
            if count != 0:
                bucket['TMR'] = bucket['type'] / count
                bucket['AMR'] = bucket['success'] / count
            else:
                bucket['TMR'] = 0
                bucket['AMR'] = 0

        stats['TSR'] = self.calculate_TSR(self.results)
        for action_type, values in stats.items():
            logger.info(f"[{action_type}] => {values}")
        self.save_results(result_path=args.result_path)

    def save_results(self, result_path):
        """Write ``results`` and ``action_stats`` as JSON to ``result_path``.

        Does nothing when ``result_path`` is falsy. Missing parent directories
        are created. A bare file name, which has no directory part, is written
        to the current working directory. Values that are not JSON-serialisable
        are written via ``str``.
        """
        if not result_path:
            return
        results_dir = os.path.dirname(result_path)
        # dirname() is '' for a bare file name; os.makedirs('') would raise.
        if results_dir:
            os.makedirs(results_dir, exist_ok=True)

        summary_data = {
            "detailed_results": self.results,
            "summary": self.action_stats,
        }
        with open(result_path, 'w') as f:
            json.dump(summary_data, f, indent=2, default=str)

        logger.info(f"\nResults have been saved in: {result_path}")


class OS_ATLAS_RES_PRE_PROCESS(Base_Model_PRE_PROCESS):
    """Scorer for OS-Atlas style outputs; also the parent of several other scorers.

    Actions are expected to be upper-case strings such as ``CLICK [[x, y]]``
    or ``TYPE [text]`` (format inferred from the parsers below). Their type
    is resolved by the inherited prefix-based ``get_action_type``.
    """

    def __init__(self) -> None:
        super().__init__()

    def extract_action(self, output_text):
        """Return the first non-blank line after 'action:' (or 'actions:').

        If neither marker is present the whole text is scanned. Returns ''
        when there is no non-blank line.
        """
        start_index = 0
        for prefix in ('action:', 'actions:'):
            if prefix in output_text:
                start_index = output_text.find(prefix) + len(prefix)
                break

        for line in output_text[start_index:].split('\n'):
            if line.strip():
                return line.strip()
        return ""

    def extract_coordinates(self, action):
        """Parse ``x, y`` from ``[[x, y]]`` (or ``[x, y]`` after ``<point>``).

        Returns ``(0.0, 0.0)`` when no bracket form is found or it cannot be
        parsed as two floats.
        """
        if "[[" in action:
            start = action.find("[[") + 2
            end = action.find("]]")
        elif "<point>" in action:
            start = action.find("[") + 1
            end = action.find("]")
        else:
            return 0.0, 0.0

        try:
            coordinates = action[start:end].split(",")
            return float(coordinates[0]), float(coordinates[1])
        except (ValueError, IndexError):
            return 0.0, 0.0

    def extract_text(self, action):
        """Return the text between the first '[' and the following ']'."""
        start = action.find("[") + 1
        end = action.find("]")
        return action[start:end]

    def Is_action_success(self, preds, gt, image_size, model_name, dataset_name, bbox):
        """Return True when the predicted action matches the ground-truth action.

        Args:
            preds: predicted action string.
            gt: ground-truth action string.
            image_size: ``(width, height)``; used to rescale UI-TARS-1.5-7B
                pixel coordinates to a 0-1000 grid.
            model_name: model identifier; checked for "UI-TARS-1.5-7B".
            dataset_name: selects distance-based (listed datasets) or
                bbox-based point matching.
            bbox: ``(x1, y1, x2, y2)`` target box for bbox-based matching.

        Point actions (CLICK / LONG_CLICK / MOVETO) match when the points are
        within 140 units, or when the prediction lies inside ``bbox``. TYPE,
        SCROLL and OPENAPP compare their bracketed text verbatim. Any other
        type matches once the types agree.
        """
        action1_type = self.get_action_type(preds)
        action2_type = self.get_action_type(gt)

        if action1_type != action2_type:
            return False

        if action1_type in (1, 2, 13):
            rescale = "UI-TARS-1.5-7B" in model_name
            if dataset_name in ['AndroidControl', 'AITZ', 'OmniAct', 'GUIAct']:
                x1, y1 = self.extract_coordinates(preds)
                x2, y2 = self.extract_coordinates(gt)
                if rescale:
                    x1, y1 = x1 / image_size[0] * 1000, y1 / image_size[1] * 1000
                    x2, y2 = x2 / image_size[0] * 1000, y2 / image_size[1] * 1000

                distance = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
                if distance > 140:
                    return False
                return True

            x, y = self.extract_coordinates(preds)
            if rescale:
                x, y = x / image_size[0] * 1000, y / image_size[1] * 1000
            coor_x1, coor_y1, coor_x2, coor_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
            if x < coor_x1 or x > coor_x2 or y < coor_y1 or y > coor_y2:
                return False
            return True

        if action1_type in (3, 4, 7):
            # Scroll directions are compared verbatim here.
            if self.extract_text(preds) != self.extract_text(gt):
                return False
        return True

    def _res_statistics(self, args, allPredResults):
        """Score every prediction record, then summarise and save.

        Each record is expected to provide 'predicted_action',
        'predicted_action_type', 'action_type' (ground-truth code),
        'real_action', 'image_size' and optionally 'bbox' / 'dataset_name'.
        Records are flagged in place with 'invalid_action', 'is_type_match'
        and 'is_success', and appended to ``self.results`` exactly once each.
        """
        # Dump the raw predictions first; self.results is then rebuilt.
        # NOTE(review): hard-coded absolute output path.
        self.results = allPredResults
        self.save_results(f'/Agent_Scankit/results/GUIOdyssey/{args.model_name}_{args.dataset_type}.json')
        self.results = []
        for record in tqdm(allPredResults, total=len(allPredResults), desc=f"Analyze {args.dataset_path}"):
            preds_action = record['predicted_action']
            check_action = record['predicted_action_type']
            gt_action_type = record['action_type']
            ground_truth = record['real_action']
            image_size = record['image_size']
            bbox = record.get('bbox')

            for code in range(1, len(self.action_types)):
                if gt_action_type != code:
                    continue
                # Ground-truth codes 11 (PRESS_APPSELECT) and 12 (INCOMPLETE)
                # are not scored.
                if code in (11, 12):
                    break
                bucket = self.action_stats[self.action_types[code - 1]]
                self.action_stats['TOTAL']['count'] += 1
                bucket['count'] += 1
                if check_action == 0:
                    record["invalid_action"] = True
                    break
                if self.Is_action_type(preds_action, ground_truth):
                    self.action_stats["TOTAL"]['type'] += 1
                    record["is_type_match"] = True
                    bucket['type'] += 1
                    if self.Is_action_success(preds_action, ground_truth, image_size, args.model_name, record.get("dataset_name"), bbox):
                        self.action_stats["TOTAL"]['success'] += 1
                        record["is_success"] = True
                        bucket['success'] += 1
                break
            self.results.append(record)

        self.post_process_res(args)

class Qwen2_VL_RES_PREPROCESS(OS_ATLAS_RES_PRE_PROCESS):
    """Qwen2-VL scorer: reuses the OS-Atlas parsing and scoring unchanged."""

    def __init__(self) -> None:
        super(Qwen2_VL_RES_PREPROCESS, self).__init__()


    
class UI_TARS_RES_PREPROCESS(OS_ATLAS_RES_PRE_PROCESS):
    """Scorer for UI-TARS outputs of the form ``Thought: ...`` / ``Action: fn(arg=...)``.

    Parsing is overridden for lower-case function-call actions. Scoring
    (``Is_action_success``) is inherited from ``OS_ATLAS_RES_PRE_PROCESS``.
    """

    def __init__(self) -> None:
        super(UI_TARS_RES_PREPROCESS, self).__init__()

    def extract_text(self, action):
        """Return the value of the first ``=value`` argument in ``action`` ('' if none).

        Single-quoted, double-quoted and bare values are accepted. Escaped
        quotes inside the value are unescaped.
        """
        found = re.search(r"=\s*'(.*?)'|=\s*\"(.*?)\"|=\s*([^\s,)]+)", action)
        if found is None:
            return ""
        single_quoted, double_quoted, bare = found.groups()
        if single_quoted is not None:
            value = single_quoted
        elif double_quoted is not None:
            value = double_quoted
        else:
            value = bare
        return value.replace("\\'", "'").replace('\\"', '"')

    @staticmethod
    def _first_line_after(text, marker):
        """Return the first non-blank stripped line after ``marker`` (whole text if absent)."""
        position = text.find(marker)
        tail = text[position + len(marker):] if position != -1 else text
        for candidate in tail.split('\n'):
            stripped = candidate.strip()
            if stripped:
                return stripped
        return ""

    def extract_action(self, output_text):
        """Return the first non-blank line following 'Action:'."""
        return self._first_line_after(output_text, 'Action:')

    def extract_thought(self, output_text):
        """Return the first non-blank line following 'Thought:'."""
        return self._first_line_after(output_text, "Thought:")

    def get_action_type(self, action):
        """Map a lower-case UI-TARS action call to its integer code (0 if unknown).

        Prefixes are tested in the order below; the first match wins.
        """
        prefix_codes = (
            ("click", 1),
            ("long_press", 2),
            ("type", 3),
            ("scroll", 4),
            ("press_back", 5),
            ("press_home", 6),
            ("open_app", 7),
            ("wait", 8),
            ("finish", 9),
            ("enter", 10),
            ("press_appselect", 11),
            ("incomplete", 12),
            ("moveto", 13),
            ("press_space", 14),
        )
        for prefix, code in prefix_codes:
            if action.startswith(prefix):
                return code
        return 0

    def extract_coordinates(self, action):
        """Return the first ``(x, y)`` numeric pair in ``action`` as floats, else (0.0, 0.0)."""
        found = re.search(r"\(\s*([-+]?\d*\.?\d+)\s*,\s*([-+]?\d*\.?\d+)\s*\)", action)
        if not found:
            return 0.0, 0.0
        return float(found.group(1)), float(found.group(2))

   
        

class GUI_R1_RES_PREPROCESS(OS_ATLAS_RES_PRE_PROCESS):
    """Scorer for GUI-R1 style outputs: ``<think>...</think>`` plus ``<answer>...</answer>``.

    The answer is parsed as a Python literal whose first element is a dict
    with an 'action' key (see ``get_action_type``). Points and 'input_text'
    are pulled out of the answer with regexes.
    """

    def __init__(self) -> None:
        super().__init__()

    def extract_coordinates(self, content):
        """Extract an integer ``[x, y]`` point from model output.

        Inside ``<answer>...</answer>`` the point is expected as ``[x, y]``
        within a ``{...}`` dict. Without the tag, ``(x, y)`` within ``{...}``
        is accepted.

        Returns:
            ``([x, y], True)`` on success, ``([0, 0, 0, 0], False)`` otherwise.
        """
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\{.*\[(\d+),\s*(\d+)]\s*.*\}'
        coord_pattern = r'\{.*\((\d+),\s*(\d+)\)\s*.*\}'
        content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
        try:
            if content_answer_match:
                content_answer = content_answer_match.group(1).strip()
                coord_match = re.search(bbox_pattern, content_answer)
            else:
                coord_match = re.search(coord_pattern, content)
            if coord_match:
                return [int(coord_match.group(1)), int(coord_match.group(2))], True
        except (TypeError, ValueError):
            pass
        return [0, 0, 0, 0], False

    def extract_action(self, content):
        """Return the text inside the first ``<answer>...</answer>`` (may span lines), else ''."""
        match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
        return match.group(1) if match else ""

    def extract_text(self, content):
        """Return the 'input_text' value from the answer, or "no input text"."""
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        action_pattern = r"'input_text':\s*'(.*?)'"
        content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
        if content_answer_match:
            content_answer = content_answer_match.group(1).strip()
            action_match = re.search(action_pattern, content_answer)
            if action_match:
                return action_match.group(1)
        return "no input text"

    def extract_thought(self, output_text):
        """Return the text inside the first ``<think>...</think>`` (may span lines), else ''."""
        match = re.search(r'<think>(.*?)</think>', output_text, re.DOTALL)
        return match.group(1) if match else ""

    def get_action_type(self, action):
        """Return the integer code of the first answer entry's 'action'.

        ``action`` is parsed with ``ast.literal_eval``. Its first element's
        'action' string is matched by prefix against the table below. Returns
        0 when the name is empty or unrecognised, or when ``action`` cannot
        be parsed (the failure is logged).
        """
        action_map = {
            'click': 1,
            'long_press': 2,
            'type': 3,
            'scroll': 4,
            'press_back': 5,
            'press_home': 6,
            'open_app': 7,
            'wait': 8,
            'complete': 9,
            'enter': 10,
            'press_appselect': 11,
            'impossible': 12,
            'moveto': 13,
            'press_pgdn': 14,
        }
        try:
            action_type = ast.literal_eval(action)[0]['action']
            if action_type:
                for key, code in action_map.items():
                    if action_type.startswith(key):
                        return code
        except Exception as e:
            logger.info(f"extract action type failure: {e}")
        return 0

    def Is_action_success(self, preds, gt, image_size, check_action, dataset_name, bbox):
        """Return True when the predicted answer matches the ground-truth answer.

        ``preds`` and ``gt`` are answer bodies (without tags). ``check_action``
        is the already-matched action code. Points are rescaled by
        ``image_size`` to a 0-1000 grid. For the listed datasets they must lie
        within 140 units of the ground truth; otherwise the prediction must lie
        inside ``bbox``. A point that cannot be parsed counts as a miss.
        TYPE / SCROLL / OPEN_APP compare 'input_text'; for SCROLL the
        ground-truth direction is inverted first.
        """
        action1_type = check_action
        if action1_type in (1, 2, 13):
            pred_coord, pred_ok = self.extract_coordinates('<answer>' + preds + '</answer>')
            if not pred_ok:
                return False
            x1, y1 = pred_coord
            x1, y1 = x1 / image_size[0] * 1000, y1 / image_size[1] * 1000

            if dataset_name in ['AndroidControl', 'AITZ', "OmniAct", 'GUIAct']:
                gt_coord, gt_ok = self.extract_coordinates('<answer>' + gt + '</answer>')
                if not gt_ok:
                    return False
                x2, y2 = gt_coord
                x2, y2 = x2 / image_size[0] * 1000, y2 / image_size[1] * 1000

                distance = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
                if distance > 140:
                    return False
                return True

            coor_x1, coor_y1, coor_x2, coor_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
            if x1 < coor_x1 or x1 > coor_x2 or y1 < coor_y1 or y1 > coor_y2:
                return False
            return True

        if action1_type in (3, 4, 7):
            text1 = self.extract_text('<answer>' + preds + '</answer>')
            text2 = self.extract_text('<answer>' + gt + '</answer>')
            if action1_type == 4:
                # Ground-truth scroll direction is inverted before comparing;
                # values outside the map are compared unchanged.
                map_direction = {"left": "right", "right": "left", "up": "down", "down": "up"}
                text2 = map_direction.get(text2, text2)
            if text1 != text2:
                return False

        return True

    def _res_statistics(self, args, allPredResults):
        """Score every record using its precomputed predicted/ground-truth codes.

        Unlike the parent, the type match is ``predicted_action_type ==
        action_type``, and only ground-truth code 11 is excluded from scoring.
        Each record is appended to ``self.results`` exactly once.
        """
        # Dump the raw predictions first; self.results is then rebuilt.
        # NOTE(review): hard-coded absolute output path.
        self.results = allPredResults
        self.save_results(f'/Agent_ScanKit/results/visual_mask/{args.model_name}_{args.dataset_type}.json')
        self.results = []
        for record in tqdm(allPredResults, total=len(allPredResults), desc=f"Analyze {args.dataset_path}"):
            preds_action = record['predicted_action']
            check_action = record['predicted_action_type']
            gt_action_type = record['action_type']
            ground_truth = record['real_action']
            image_size = record['image_size']
            bbox = record.get('bbox')

            for code in range(1, len(self.action_types)):
                if gt_action_type != code:
                    continue
                if code == 11:
                    break
                bucket = self.action_stats[self.action_types[code - 1]]
                self.action_stats['TOTAL']['count'] += 1
                bucket['count'] += 1
                if check_action == 0:
                    record["invalid_action"] = True
                    break
                if check_action == gt_action_type:
                    self.action_stats["TOTAL"]['type'] += 1
                    record["is_type_match"] = True
                    bucket['type'] += 1
                    if self.Is_action_success(preds_action, ground_truth, image_size, check_action, record['dataset_name'], bbox):
                        self.action_stats["TOTAL"]['success'] += 1
                        record["is_success"] = True
                        bucket['success'] += 1
                break
            self.results.append(record)
        self.post_process_res(args)

class Agent_CPM_RES_PREPROCESS(GUI_R1_RES_PREPROCESS):
    """Scorer for AgentCPM-GUI outputs, which are Python-literal dicts.

    Predictions and ground truth are parsed with ``ast.literal_eval``. Keys
    such as 'POINT', 'to', 'duration', 'PRESS', 'TYPE', 'open_app' and
    'STATUS' determine the action type.
    """

    def __init__(self) -> None:
        super().__init__()

    def extract_coordinates(self, content):
        """Return the 'POINT' pair of the parsed dict, or (0, 0) on failure (logged)."""
        try:
            x, y = ast.literal_eval(content)['POINT']
        except Exception:
            logger.info("extract coordinates failure")
            return (0, 0)
        return (x, y)

    def extract_action(self, content):
        """The raw output already is the action dict literal; return it unchanged."""
        return content

    def extract_thought(self, output_text):
        """Return the 'thought' entry of the parsed dict, or '' on failure (logged)."""
        try:
            thought = ast.literal_eval(output_text)['thought']
        except Exception:
            logger.info("extract thought failure")
            return ""
        return thought

    def extract_text(self, content, key):
        """Return ``literal_eval(content)[key]``, or None on failure (logged)."""
        try:
            return ast.literal_eval(content)[key]
        except Exception:
            logger.info("extract text failure")
            return None

    def get_action_type(self, action):
        """Classify an AgentCPM action dict literal into an integer action code.

        The second key of the parsed dict is treated as the action key
        (presumably the first is 'thought' — confirm against the model's
        output format). Returns 0 for unrecognised actions or on parse errors.
        """
        try:
            preds = ast.literal_eval(action)
            keys = list(preds.keys())

            if len(keys) < 2:
                return 0

            action_type = keys[1]

            if action_type == 'POINT':
                if len(keys) == 3:
                    if keys[2] == 'duration':
                        return 2  # long_press
                    if keys[2] == 'to':
                        return 4  # scroll
                return 1  # click

            if action_type == 'PRESS':
                press_codes = {'HOME': 6, 'BACK': 5, 'ENTER': 10, 'SPACE': 14}
                # Unknown press targets are invalid (0) rather than None.
                return press_codes.get(preds['PRESS'], 0)

            if action_type == 'TYPE':
                return 3
            if action_type == 'open_app':
                return 7
            if action_type == 'duration':
                return 8  # wait
            if action_type == 'STATUS':
                if preds['STATUS'] == 'finish':
                    return 9
                if preds['STATUS'] == 'impossible':
                    # NOTE(review): 11 is PRESS_APPSELECT in action_types, and
                    # the inherited _res_statistics skips ground-truth code
                    # 11, so "impossible" steps are excluded from scoring.
                    return 11
                return 0
            return 0
        except Exception as e:
            logger.info(f"extract action type failure: {e}")
            return 0

    def Is_action_success(self, preds, gt, image_size, action_type, dataset_name, bbox):
        """Return True when the predicted action dict matches the ground truth.

        ``action_type`` is the already-matched action code. For point actions,
        only the ground truth is rescaled by ``image_size`` to a 0-1000 grid.
        The prediction is presumably already on that grid, so confirm against
        the model's output. Scroll, type, press, open_app and status actions
        compare the corresponding dict value. WAIT and any other type succeed
        once the types match.
        """
        if action_type in [1, 2]:
            if dataset_name in ['AndroidControl', 'AITZ', 'OmniAct', 'GUIAct']:
                x1, y1 = self.extract_coordinates(preds)
                x2, y2 = self.extract_coordinates(gt)
                x2 = x2 / image_size[0] * 1000
                y2 = y2 / image_size[1] * 1000
                distance = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
                return distance <= 140
            x, y = self.extract_coordinates(preds)
            coor_x1, coor_y1, coor_x2, coor_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
            if x < coor_x1 or x > coor_x2 or y < coor_y1 or y > coor_y2:
                return False
            return True

        if action_type == 4:
            return self.extract_text(preds, 'to') == self.extract_text(gt, 'to')

        if action_type == 3:
            return self.extract_text(preds, 'TYPE') == self.extract_text(gt, 'TYPE')

        if action_type in [5, 6, 10, 14]:
            return self.extract_text(preds, 'PRESS') == self.extract_text(gt, 'PRESS')

        if action_type == 7:
            return self.extract_text(preds, 'open_app') == self.extract_text(gt, 'open_app')

        if action_type == 9:
            return self.extract_text(preds, 'STATUS') == self.extract_text(gt, 'STATUS')

        return True

    def _res_statistics(self, args, allPredResults):
        """Delegate to the GUI-R1 scoring loop."""
        return super()._res_statistics(args, allPredResults)

      
    
class AGUVIS_RES_PREPROCESS(Agent_CPM_RES_PREPROCESS):
    def __init__(self) -> None:
        """Set up the shared counters through the parent scorer chain."""
        super(AGUVIS_RES_PREPROCESS, self).__init__()

    def extract_action(self, content):
        """Convert the raw AGUVIS output into an action dict via ``mapping_actions``."""
        mapped_action = self.mapping_actions(content)
        return mapped_action

    def extract_thought(self, output_text):
        """Return the 'Thought' entry of the mapped action, or '' on failure (logged).

        NOTE(review): ``mapping_actions`` fills 'Thought' with the text
        captured after "Action:" in the raw output. Its failure dict has no
        'Thought' key, so a parse failure also yields ''.
        """
        try:
            content = self.mapping_actions(output_text)
            thought = content['Thought']
        except Exception:
            logger.info("extract thought failure")
            return ""
        return thought
    
    def extract_text(self, content, key):
        """Return ``content[key]`` from a mapped action dict, or None on failure (logged)."""
        try:
            return content[key]
        except Exception:
            logger.info("extract text failure")
            return None
    
    def extract_coordinates(self, content):
        """Return the 'POINT' pair of a mapped action dict, or (0, 0) on failure (logged)."""
        try:
            x, y = content['POINT']
        except Exception:
            logger.info("extract coordinates failure")
            return (0, 0)
        return (x, y)
    
    def get_action_type(self, pred):
        """Classify a mapped AGUVIS action dict into an integer action code.

        ``pred`` is expected to be a dict such as ``mapping_actions`` returns.
        Its first key selects the action family. Returns 0 for unrecognised
        actions or on error (logged).
        """
        try:
            keys = list(pred.keys())
            if not keys:
                return 0

            action_type = keys[0]

            if action_type == 'open_app':
                return 7
            if action_type == 'POINT':
                if 'to' in keys:
                    return 4  # scroll
                duration = pred.get('duration', 0)
                # NOTE(review): mapping_actions emits duration 200 for plain
                # clicks, so they land in LONG_CLICK (2); post_process_res
                # later folds LONG_CLICK into CLICK.
                if duration == 200:
                    return 2
                if duration == 2000:
                    return 13  # moveto
                return 1
            if action_type == 'TYPE':
                return 3
            if action_type == 'PRESS':
                # SPACE is PRESS_SPACE (14), matching Is_action_success's press
                # family. NOTE(review): APPSELECT maps to 12 (INCOMPLETE), not
                # 11 (PRESS_APPSELECT) — confirm this is intended.
                press_codes = {'HOME': 6, 'BACK': 5, 'ENTER': 10, 'APPSELECT': 12, 'SPACE': 14}
                return press_codes.get(pred.get('PRESS', ''), 0)
            if action_type == 'duration':
                return 8  # wait
            if action_type == 'STATUS':
                return 9 if pred.get('STATUS') == 'finish' else 12
            return 0
        except Exception as e:
            logger.info(f"[get_action_type] Failed to extract: {e}")
            return 0
        
    def Is_action_success(self, preds, gt, image_size, action_type, dataset_name, bbox):
        """Return True when the mapped predicted action matches the ground truth.

        ``action_type`` is the already-matched action code. For point actions
        (1, 2, 13), only the ground truth is rescaled by ``image_size`` to a
        0-1000 grid, and a match is within 140 units or inside ``bbox``.
        For scroll, the ground-truth direction is inverted before comparing.
        Type, press, open_app and status compare the corresponding dict value.
        Any other type succeeds once the types match.
        """
        if action_type in [1, 2, 13]:
            if dataset_name in ['AndroidControl', 'AITZ', 'OmniAct', 'GUIAct']:
                x1, y1 = self.extract_coordinates(preds)
                x2, y2 = self.extract_coordinates(gt)
                x2, y2 = x2 / image_size[0] * 1000, y2 / image_size[1] * 1000
                distance = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
                return distance <= 140
            x, y = self.extract_coordinates(preds)
            coor_x1, coor_y1, coor_x2, coor_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
            if x < coor_x1 or x > coor_x2 or y < coor_y1 or y > coor_y2:
                return False
            return True

        if action_type == 4:
            scroll1 = self.extract_text(preds, 'to')
            scroll2 = self.extract_text(gt, 'to')
            # Values outside the map (e.g. a failed extraction) are compared as-is.
            map_direction = {"left": "right", "right": "left", "up": "down", "down": "up"}
            scroll2 = map_direction.get(scroll2, scroll2)
            return scroll1 == scroll2

        if action_type == 3:
            return self.extract_text(preds, 'TYPE') == self.extract_text(gt, 'TYPE')

        if action_type in [5, 6, 10, 14]:
            return self.extract_text(preds, 'PRESS') == self.extract_text(gt, 'PRESS')

        if action_type == 7:
            return self.extract_text(preds, 'open_app') == self.extract_text(gt, 'open_app')

        if action_type == 9:
            return self.extract_text(preds, 'STATUS') == self.extract_text(gt, 'STATUS')

        return True

    def mapping_actions(self, pred: str) -> str:
        """
        Mapping the string from aguvis model into minicpm Action space.

        Args:
            episode (dict): The episode dict, containing all information and a prediction string.

        Returns:
            the episode whose prediction string is mapped into minicpm Action space.

        In practice the model will output unstable strings. We only handle those stable cases.
        """

        FAIL_PARSE = {
            "STATUS": "FAIL"
        }
        thought = ""
        try:
            match = re.search(r"Action:\s*(.+)", pred)
            if match:
                thought:str = match.group(1)
        except:
            logger.info('extract thought failure')
        action:str = pred.split('\n')[-1].strip()

        platform = action.split('.')[0]

        function = action[len(platform) + 1 :]
        if platform == "pyautogui":
            # press(keys=['enter'])
            if function.startswith("click"):
                # deal with click function.
                try:
                    matches = re.findall(r"[-+]?\d*\.\d+|\d+", function)
                    x,y = matches
                    x = round(float(x) * 1000)
                    y = round(float(y) * 1000)

                    pred = {
                        "POINT": [x, y],
                        "duration": 200,
                        "STATUS": "continue",
                        "Thought": thought
                    }
                except Exception as e:
                    logger.info(f"Failed to parse POINT ACTION {function}: {e}")
                    pred = FAIL_PARSE
            
            elif function.startswith("moveto"):
                # deal with click function.
                try:
                    matches = re.findall(r"[-+]?\d*\.\d+|\d+", function)
                    x,y = matches
                    x = round(float(x) * 1000)
                    y = round(float(y) * 1000)

                    pred = {
                        "POINT": [x, y],
                        "duration": 2000,
                        "STATUS": "continue",
                        "Thought": thought
                    }
                except Exception as e:
                    logger.info(f"Failed to parse POINT ACTION {function}: {e}")
                    pred = FAIL_PARSE

            elif function.startswith("write"):
                # deal with type action.
                try:

                    pattern = r'message=(["\'])(.*?)\1'
                    match = re.search(pattern, function)

                    text = match.group(2)

                    pred = {
                        "TYPE": text,
                        "duration": 200,
                        "STATUS": "continue",
                        "Thought": thought
                    }

                except Exception as e:
                    logger.info(f"Failed to parse TYPE ACTION {function}: {e}")
                    pred = FAIL_PARSE

            elif function.startswith("scroll"):
                # deal with scroll up/down
                try:
                    pattern = r'scroll\(page=([-+]?\d*\.\d+|\d+)\)'
                    match = re.match(pattern, function)

                    value = float(match.group(1))

                    pred = {
                        "POINT": [500, 500],
                        "to": "up" if value > 0 else "down",
                        "duration": 200,
                        "STATUS": "continue",
                        "Thought": thought
                    }
                except Exception as e:
                    logger.info(f"Failed to parse MOVE_TO ACTION {function}: {e}")
                    pred = FAIL_PARSE

            elif function.startswith("hscroll"):
                # deal with scroll left/right
                try:
                    pattern = r'hscroll\(page=([-+]?\d*\.\d+|\d+)\)'
                    match = re.match(pattern, function)

                    value = float(match.group(1))

                    pred = {
                        "POINT": [500, 500],
                        "to": "left" if value < 0 else "right",
                        "duration": 200,
                        "STATUS": "continue",
                        "Thought": thought
                    }

                except Exception as e:
                    logger.info(f"Failed to parse MOVE_TO ACTION {function}: {e}")
                    pred = FAIL_PARSE
            elif "enter" in function:
                pred = {
                    "PRESS": "ENTER",
                    "duration": 200,
                    "STATUS": "continue",
                    "Thought": thought
                }
            elif function.startswith("space"):
                pred = {
                    "PRESS": "SPACE",
                    "duration": 200,
                    "STATUS": "continue",
                    "Thought": thought
                }
            else:
                logger.info(f"Unrecognize action in {platform}: {function}")
                pred = FAIL_PARSE

        elif platform == "mobile":

            if function.startswith("back"):
                # deal with back action.
                pred = {
                    "PRESS": "BACK",
                    "duration": 200,
                    "STATUS": "continue",
                    "Thought": thought
                }
            elif function.startswith("home"):
                # deal with home action.
                pred = {
                    "PRESS": "HOME",
                    "duration": 200,
                    "STATUS": "continue",
                    "Thought": thought
                }
            elif function.startswith("terminate"):
                # deal with terminate action.
                if 'success' in action:
                    pred = {
                        "STATUS": "finish",
                        "Thought": thought
                    }
                else:
                    pred = {
                        "STATUS": "interrupt",
                        "Thought": thought
                    }

            elif function.startswith("open_app"):
                # deal with open_app action. This action will not be accepted by our evaluation.
                try:
                    match = re.search(r"app_name='([^']+)'", function)
                    app_name = match.group(1)

                    pred = {
                        "open_app": app_name,
                        "duration": 200,
                        "STATUS": "continue",
                        "Thought": thought
                    }

                except Exception as e:
                    logger.info(f"Failed to parse open_app ACTION {function}: {e}")
                    pred = FAIL_PARSE

            elif function.startswith("wait"):
                # deal with wait action.
                pred = {
                    "duration": 3000,
                    "STATUS": "continue",
                    "Thought": thought
                }

            elif function.startswith("long_press"):
                # deal with long_press action.
                try:
                    matches = re.findall(r"[-+]?\d*\.\d+|\d+", action)
                    x,y = matches
                    x = round(float(x) * 1000)
                    y = round(float(y) * 1000)

                    pred = {
                        "POINT": [x, y],
                        "duration": 1000,
                        "STATUS": "continue",
                        "Thought": thought
                    }

                except Exception as e:
                    logger.info(f"Failed to parse LONG_PRESS ACTION {function}: {e}")
                    pred = FAIL_PARSE

            else:
                logger.info(f"Unrecognize action in {platform}: {function}")
                pred = FAIL_PARSE

        else:
            # Any other unstable output will be informed.
            # Unrecognize action in pyautogui: press(keys=['enter'])
            logger.info(f'Unrecognize output: {repr(pred)}.')
            pred = FAIL_PARSE

        return pred
    
class OS_GENESIS_PREPROCESS(AGUVIS_RES_PREPROCESS):
    """Pre-processing and scoring for OS-Genesis model outputs.

    The model output contains ``Low-level thought: ... action: {...}``.
    ``mapping_actions`` converts the ``action:`` object into a dict holding a
    ``STATUS`` key plus an action-specific key (``POINT``, ``TYPE``, ``PRESS``,
    ``OPEN_APP`` or ``duration``).
    """

    def __init__(self, policy_name) -> None:
        super().__init__()
        # Read by Is_action_success to choose the scroll-direction convention.
        self.policy_name = policy_name

    def extract_action(self, content):
        """Return the evaluator-format action dict parsed from ``content``."""
        return self.mapping_actions(content)

    def extract_thought(self, output_text):
        """Return the text between ``Low-level thought:`` and ``action:``.

        Logs and returns "" when the pattern is not present.
        """
        match = re.search(r"Low-level thought:\s*(.*?)\s*action:", output_text, re.DOTALL)
        if match:
            thought = match.group(1).strip()
            return thought
        else:
            logger.info(f"extract thought failure: {output_text}")
            return ""

    def Is_action_success(self, preds, gt, image_size, action_type, dataset_name, bbox):
        """Return True when the parameters of ``preds`` match those of ``gt``.

        ``action_type`` is the numeric type from ``get_action_type``. Types
        without a parameter check below are accepted.
        """
        if action_type in [1, 2, 13]:
            if dataset_name in ['AndroidControl', 'AITZ', 'OmniAct', "GUIAct"]:
                # Rescale both points by image_size onto a 0-1000 grid and accept
                # predictions within 140 units of the ground truth.
                x1, y1 = self.extract_coordinates(preds)
                x1, y1 = x1/image_size[0]*1000, y1/image_size[1]*1000
                x2, y2 = self.extract_coordinates(gt)
                x2, y2 = x2/image_size[0]*1000, y2/image_size[1]*1000

                distance = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
                return distance <= 140
            else:
                # The raw predicted point must fall inside the ground-truth bbox
                # (x1, y1, x2, y2). NOTE(review): assumes both use the same scale.
                x, y = self.extract_coordinates(preds)
                x, y = float(x), float(y)
                coor_x1, coor_y1, coor_x2, coor_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                if x < coor_x1 or x > coor_x2 or y < coor_y1 or y > coor_y2:
                    return False
                else:
                    return True

        elif action_type == 4:
            scroll1 = self.extract_text(preds, 'to')
            scroll2 = self.extract_text(gt, 'to')
            if "7B" in self.policy_name:
                # Presumably the 7B policy reports scroll direction with the
                # opposite convention, so the ground truth is flipped first.
                map_direction = {"left": "right", "right": "left", "up": "down", "down": "up"}
                # Missing/unknown directions are kept as-is instead of raising KeyError.
                scroll2 = map_direction.get(scroll2, scroll2)
            return scroll1 == scroll2

        elif action_type == 3:
            text1 = self.extract_text(preds, 'TYPE')
            text2 = self.extract_text(gt, 'TYPE')
            return text1 == text2

        elif action_type in [5, 6, 14]:
            key1 = self.extract_text(preds, 'PRESS')
            key2 = self.extract_text(gt, 'PRESS')
            return key1 == key2

        elif action_type == 7:
            # OPEN_APP: the app names must match.
            key1 = self.extract_text(preds, 'OPEN_APP')
            key2 = self.extract_text(gt, 'OPEN_APP')
            return key1 == key2

        return True

    def get_action_type(self, pred):
        """Map an action dict to a numeric action type (0 = unknown).

        Numbering follows ``Base_Model_PRE_PROCESS``: 1 click, 2 long press,
        3 type, 4 scroll, 5 back, 6 home, 7 open app, 8 wait, 9 finished,
        10 enter, 11 app-select, 12 impossible/incomplete, 13 move-to, 14 space.
        The action is identified by which keys are present, so the result does
        not depend on dict insertion order.
        """
        try:
            if not pred:
                return 0
            status = pred['STATUS']
            if status == 'finished':
                return 9
            if status == 'impossible':
                return 12

            if 'OPEN_APP' in pred:
                return 7
            if 'POINT' in pred:
                if 'to' in pred:
                    return 4
                duration = pred.get('duration', 0)
                if duration == 1000:
                    return 2
                elif duration == 2000:
                    return 13
                else:
                    return 1
            if 'TYPE' in pred:
                return 3
            if 'PRESS' in pred:
                press_value = pred.get('PRESS', '')
                if press_value == 'HOME':
                    return 6
                elif press_value == 'BACK':
                    return 5
                elif press_value == 'ENTER':
                    return 10
                elif press_value == 'APPSELECT':
                    return 11
                elif press_value == 'SPACE':
                    return 14
                return 0
            # A dict carrying only STATUS and duration=200 is a wait action.
            if pred.get('duration') == 200:
                return 8
            return 0
        except Exception as e:
            logger.info(f"[get_action_type] Failed to extract: {e}")
            return 0

    def mapping_actions(self, action_str):
        """Convert the ``action: {...}`` part of a model output into an action dict.

        Always returns a dict with a ``STATUS`` key. If the action cannot be
        located or decoded, the failure is logged and only
        ``{"STATUS": "continue"}`` is returned. ``get_action_type`` maps that
        dict to 0.
        """
        result = {"STATUS": "continue"}
        try:
            action_start = action_str.find("action: ")
            if action_start == -1:
                raise ValueError("Cannot find action information")

            action_json_str = action_str[action_start + len("action: "):].strip()
            # Normalise quoting of keys and single-quoted values so demjson3 can decode them.
            action_json_str = re.sub(r"([,{]\s*)['\"]([^'\"]+?)\"?\s*:", r'\1"\2":', action_json_str)
            action_json_str = re.sub(
                r':\s*\'([^\']*)\'',
                lambda m: ':"{}"'.format(m.group(1).replace('"', '\\"')),
                action_json_str
            )
            action_dict = demjson3.decode(action_json_str)
            if not isinstance(action_dict, dict):
                raise ValueError(f"Decoded action is not an object: {action_dict!r}")

            action_type = action_dict.get("action_type")

            if action_type == "type":
                result["TYPE"] = action_dict.get("text", "")
            elif action_type == "click":
                result["POINT"] = [action_dict.get("x", 0), action_dict.get("y", 0)]
            elif action_type == 'moveto':
                result["POINT"] = [action_dict.get("x", 0), action_dict.get("y", 0)]
                result['duration'] = 2000
            elif action_type == "navigate_home":
                result["PRESS"] = "HOME"
            elif action_type == "navigate_back":
                result["PRESS"] = "BACK"
            elif action_type == 'navigate_space':
                result['PRESS'] = 'SPACE'
            elif action_type == 'enter':
                result["PRESS"] = "ENTER"
            elif action_type == 'navigate_appselect':
                result['PRESS'] = 'APPSELECT'
            elif action_type == "scroll":
                result["POINT"] = [500, 500]  # set default start point
                direction = action_dict.get("direction", "down").strip().lower()
                result["to"] = direction
            elif action_type == "open_app":
                result["OPEN_APP"] = action_dict.get("app_name", "")
            elif action_type == "wait":
                result["duration"] = 200
            elif action_type == "dismiss":
                result["POINT"] = [action_dict.get("x", 0), action_dict.get("y", 0)]
            elif action_type == 'stop':
                result['STATUS'] = 'finished'
            elif action_type == 'impossible':
                result['STATUS'] = 'impossible'
            elif action_type == "long_press":
                result["POINT"] = [action_dict.get("x", 0), action_dict.get("y", 0)]
                result["duration"] = 1000  # set default duration
            else:
                logger.info(f"Error, invalid action: {action_dict}")

        except Exception as e:
            # Model output is untrusted. demjson3 raises its own exception types,
            # and malformed fields can raise ValueError/TypeError/AttributeError.
            # Like the file's other parsers, treat any failure as an unparsable
            # action and discard partially filled fields.
            logger.info(f"Cannot parse action information as JSON: {e}")
            result = {"STATUS": "continue"}

        return result
    

class GUI_ODYSSEY_PREPROCESS(Agent_CPM_RES_PREPROCESS):
    """Pre-processing and scoring for GUI-Odyssey style outputs.

    Actions are plain strings such as ``CLICK: (x, y)``, ``PRESS_HOME``,
    ``TYPE: text`` or ``SCROLL: UP``. ``mapping_actions`` converts them into
    dicts whose first key names the action.
    """

    def __init__(self) -> None:
        super().__init__()

    def extract_action(self, content):
        """Return the evaluator-format action dict parsed from ``content``."""
        return self.mapping_actions(content)

    def extract_thought(self, output_text):
        """Return the raw output unchanged; it is used as the thought."""
        return output_text

    def extract_coordinates(self, content):
        """Return ``content['POINT']`` as ``(x, y)``, or ``(0, 0)`` if unavailable."""
        try:
            x, y = content['POINT']
        except (KeyError, TypeError, ValueError):
            logger.info("extract coordinates failure")
            return (0, 0)
        return (x, y)

    def extract_text(self, content, key):
        """Return ``content[key]``, or None (after logging) if it cannot be read."""
        try:
            return content[key]
        except (KeyError, TypeError):
            logger.info("extract text failure")
            return None

    def get_action_type(self, pred):
        """Map an action dict to a numeric action type (0 = unknown).

        1 click, 2 long press, 3 type, 4 scroll, 5 back, 6 home, 7 open app,
        8 wait, 9 finish, 10 enter, 11 app-select, 12 impossible, 14 space.
        The action key is expected first in ``pred``, as ``mapping_actions``
        produces it.
        """
        try:
            keys = list(pred.keys())
            if not keys:
                return 0
            if pred['STATUS'] == 'finish':
                return 9
            if pred['STATUS'] == 'impossible':
                return 12
            action_type = keys[0]

            if action_type == 'OPEN_APP':
                return 7
            elif action_type == 'POINT':
                if 'to' in keys:
                    return 4
                duration = pred.get('duration', 0)
                return 2 if duration == 1000 else 1
            elif action_type == 'TYPE':
                return 3
            elif action_type == 'PRESS':
                press_value = pred.get('PRESS', '')
                if press_value == 'HOME':
                    return 6
                elif press_value == 'BACK':
                    return 5
                elif press_value == 'ENTER':
                    return 10
                elif press_value == 'APPSELECT':
                    # 11, not 12: 12 is reserved for 'impossible', and
                    # Is_action_success compares PRESS for type 11.
                    return 11
                elif press_value == 'SPACE':
                    return 14
            elif action_type == 'duration' and pred['duration'] == 200:
                return 8
            return 0
        except Exception as e:
            logger.info(f"[get_action_type] Failed to extract: {e, pred}")
            return 0

    def mapping_actions(self, action:str) -> dict:
        """Convert a GUI-Odyssey action string into an action dict.

        Raises:
            ValueError: a CLICK/LONG_PRESS string does not contain ``(x, y)``.
            NotImplementedError: the action keyword is not recognised.
        """
        if action.startswith("CLICK"):
            pattern = r"CLICK: \((\d+),\s*(\d+)\)"
            match = re.match(pattern, action)
            if match is None:
                raise ValueError(f"Malformed CLICK action: {action}")
            x = int(match.group(1))
            y = int(match.group(2))

            return {
                "POINT": [x, y],
                "duration": 200,
                "STATUS": "continue"
            }
        elif action.startswith("PRESS"):
            # PRESS_HOME -> HOME, PRESS_BACK -> BACK; PRESS_RECENT maps to APPSELECT.
            sub_action = action.split("_")[-1]
            return {
                "PRESS": sub_action if sub_action != 'RECENT' else 'APPSELECT',
                "duration": 200,
                "STATUS": "continue"
            }
        elif action.startswith("ENTER"):
            return {
                "PRESS": "ENTER",
                "duration": 200,
                "STATUS": "continue"
            }
        elif action.startswith("OPENAPP"):
            text = action.split(":")[-1].strip()
            return {
                "OPEN_APP": text,
                "duration": 200,
                "STATUS": "continue"
            }
        elif action.startswith("TYPE"):
            text = action.split(":")[-1].strip()
            return {
                "TYPE": text,
                "duration": 200,
                "STATUS": "continue"
            }
        elif action == "COMPLETE":
            return {
                "STATUS": "finish"
            }
        elif action == "IMPOSSIBLE":
            return {
                "STATUS": "impossible"
            }
        elif action.startswith("SCROLL"):
            direction = action.split(":")[-1].strip()
            return {
                "POINT": [500, 500],
                "to": direction.lower(),
                "duration": 200,
                "STATUS": "continue"
            }
        elif action.startswith("LONG_PRESS"):
            pattern = r"LONG_PRESS: \((\d+),\s*(\d+)\)"
            match = re.match(pattern, action)
            if match is None:
                raise ValueError(f"Malformed LONG_PRESS action: {action}")
            x = int(match.group(1))
            y = int(match.group(2))
            return {
                "POINT": [x, y],
                "duration": 1000,
                "STATUS": "continue"
            }
        elif action.startswith("WAIT"):
            return {
                "duration": 200,
                "STATUS": "continue"
            }
        else:
            logger.info(f"Unsupported action: {action}")
            raise NotImplementedError(action)

    def Is_action_success(self, preds, gt, image_size, action_type, dataset_name, bbox):
        """Return True when the parameters of ``preds`` match those of ``gt``.

        Types without a parameter check below (e.g. wait, space) are accepted.
        """
        if action_type in [1, 2]:
            if dataset_name in ['AndroidControl', 'AITZ', 'OmniAct']:
                # Point-distance check with a 140-unit threshold on raw coordinates.
                x1, y1 = self.extract_coordinates(preds)
                x2, y2 = self.extract_coordinates(gt)
                distance = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
                return distance <= 140
            else:
                # The predicted point must lie inside the bbox (x1, y1, x2, y2).
                x, y = self.extract_coordinates(preds)
                coor_x1, coor_y1, coor_x2, coor_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                if x < coor_x1 or x > coor_x2 or y < coor_y1 or y > coor_y2:
                    return False
                else:
                    return True

        elif action_type == 4:
            scroll1 = self.extract_text(preds, 'to')
            scroll2 = self.extract_text(gt, 'to')
            return scroll1 == scroll2

        elif action_type == 3:
            text1 = self.extract_text(preds, 'TYPE')
            text2 = self.extract_text(gt, 'TYPE')
            return text1 == text2

        elif action_type in [5, 6, 10, 11]:
            key1 = self.extract_text(preds, 'PRESS')
            key2 = self.extract_text(gt, 'PRESS')
            return key1 == key2

        elif action_type == 7:
            # OPEN_APP: the app names must match.
            key1 = self.extract_text(preds, 'OPEN_APP')
            key2 = self.extract_text(gt, 'OPEN_APP')
            return key1 == key2

        elif action_type == 9 or action_type == 12:
            status1 = self.extract_text(preds, 'STATUS')
            status2 = self.extract_text(gt, 'STATUS')
            return status1 == status2

        return True
  
class GUI_OWL_PREPROCESS(GUI_R1_RES_PREPROCESS):
    """Pre-processing and scoring for GUI-Owl tool-call outputs.

    The action is the ``arguments`` object of a JSON payload wrapped in
    ``<tool_call>...</tool_call>``. Its ``action`` field names the action type.
    """

    def __init__(self) -> None:
        super().__init__()

    def extract_action(self, content):
        """Return the tool-call ``arguments`` dict, or ``{"action": "error"}`` on failure."""
        pattern = r"<tool_call>(.*?)</tool_call>"
        match = re.search(pattern, content, re.S)
        if match:
            tool_content = match.group(1).strip()
            try:
                action_str = json.loads(tool_content)['arguments']
                return action_str
            except (ValueError, KeyError, TypeError) as e:
                # ValueError covers json.JSONDecodeError; KeyError/TypeError cover a
                # payload without an 'arguments' field or one that is not an object.
                logger.warning(f"json loads failure: {e}, raw={tool_content}")
                return {"action": "error"}
        return {"action": "error"}

    def extract_thought(self, output_text):
        """Return the stripped text inside ``<thinking>...</thinking>``, or ""."""
        pattern = r"<thinking>(.*?)</thinking>"
        match = re.search(pattern, output_text, re.S)
        if match:
            content = match.group(1).strip()
            return content
        return ""

    def get_action_type(self, action):
        """Map a tool-call ``arguments`` dict to a numeric action type.

        1 click, 2 long press, 3 type, 4 swipe, 5 back, 6 home, 7 open,
        8 wait, 9 terminate, 10 enter. Anything unrecognised (including
        unknown system buttons and non-dict input) returns 0.
        """
        if not isinstance(action, dict):
            return 0
        action_type = action.get('action')
        if action_type == 'click':
            return 1
        elif action_type == 'long_press':
            return 2
        elif action_type == 'type':
            return 3
        elif action_type == 'swipe':
            return 4
        elif action_type == 'system_button':
            button = action.get('button')
            if button == 'Home':
                return 6
            elif button == 'Back':
                return 5
            elif button == 'Enter':
                return 10
            return 0
        elif action_type == 'open':
            return 7
        elif action_type == 'wait':
            return 8
        elif action_type == 'terminate':
            return 9
        else:
            return 0

    def extract_text(self, action):
        """Return the ``text`` argument (typed text for 'type', app name for 'open')."""
        return action['text']

    def get_direction(self, action):
        """Return the swipe direction from ``coordinate`` to ``coordinate2``.

        The axis with the larger displacement wins; ties go to the horizontal axis.
        """
        x1, y1 = action['coordinate']
        x2, y2 = action['coordinate2']
        dx = x2 - x1
        dy = y2 - y1
        if abs(dy) > abs(dx):
            if dy < 0:
                direction = "up"
            else:
                direction = "down"
        else:
            if dx < 0:
                direction = "left"
            else:
                direction = "right"
        return direction

    def get_direction2(self, action):
        """Like ``get_direction`` but with the axis labels swapped.

        Vertical motion is reported as left/right and horizontal motion as
        up/down. NOTE(review): presumably for a rotated-screen convention; it
        is not referenced within this class.
        """
        x1, y1 = action['coordinate']
        x2, y2 = action['coordinate2']
        dx = x2 - x1
        dy = y2 - y1
        if abs(dy) > abs(dx):
            if dy < 0:
                direction = "left"
            else:
                direction = "right"
        else:
            if dx < 0:
                direction = "up"
            else:
                direction = "down"
        return direction

    def Is_action_success(self, preds, gt, image_size, check_action, dataset_name, bbox):
        """Return True when the parameters of ``preds`` match those of ``gt``.

        ``check_action`` is the numeric type from ``get_action_type``. Types
        without a parameter check below are accepted.
        """
        action1_type = check_action
        if action1_type in (1, 2, 13):
            if bbox == "":
                # No bbox: rescale the prediction by image_size onto a 0-1000 grid
                # and accept it within 140 units of the ground truth.
                # NOTE(review): the ground truth is not rescaled, so it is assumed
                # to already be on the 0-1000 scale.
                x1, y1 = preds['coordinate']
                x2, y2 = gt['coordinate']
                x1, y1 = x1 / image_size[0] * 1000, y1 / image_size[1] * 1000
                dx = abs(x1 - x2)
                dy = abs(y1 - y2)

                distance = math.sqrt(dx ** 2 + dy ** 2)
                if distance > 140:
                    return False
                else:
                    return True
            else:
                # The rescaled prediction must lie inside the bbox (x1, y1, x2, y2).
                x, y = preds['coordinate']
                x, y = x / image_size[0] * 1000, y / image_size[1] * 1000
                coor_x1, coor_y1, coor_x2, coor_y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                if x < coor_x1 or x > coor_x2 or y < coor_y1 or y > coor_y2:
                    return False
                else:
                    return True

        elif action1_type == 3 or action1_type == 7:
            text1 = self.extract_text(preds)
            text2 = self.extract_text(gt)
            return text1 == text2
        elif action1_type == 4:
            text1 = self.get_direction(preds)
            text2 = self.get_direction(gt)
            return text1 == text2
        return True

    def _res_statistics(self, args, allPredResults):
        """Delegate result aggregation to the parent class."""
        return super()._res_statistics(args, allPredResults)


